// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

// This program is the generator for the gRPC service wrapper types in the
// parent directory. It's not suitable for any other use.
//
// This makes various assumptions about how the protobuf compiler and
// gRPC stub generators produce code. If those significantly change in future
// then this will probably break.
package main

import (
	"bytes"
	"fmt"
	"go/format"
	"go/types"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"strings"

	"golang.org/x/tools/go/packages"
)

var (
	// protobufPkgs maps the short name used as the import alias in each
	// generated wrapper file to the import path of the protobuf/gRPC
	// package whose server interfaces get wrappers.
	protobufPkgs = map[string]string{
		"dependencies": "github.com/hashicorp/terraform/internal/rpcapi/terraform1/dependencies",
		"packages":     "github.com/hashicorp/terraform/internal/rpcapi/terraform1/packages",
		"stacks":       "github.com/hashicorp/terraform/internal/rpcapi/terraform1/stacks",
	}

	// additionalImportsByName holds an extra, already-quoted import line to
	// include in the generated files for a given short name. A short name
	// with no entry gets no extra import (the map lookup yields "").
	additionalImportsByName = map[string]string{
		"dependencies": `"google.golang.org/grpc"`,
		"stacks":       `"google.golang.org/grpc"`,
	}
)

// main loads each protobuf/gRPC package listed in protobufPkgs. For every
// generated gRPC server interface found there, it writes a wrapper type into
// the dynrpcserver directory. That directory is located relative to the
// package's first Go file.
//
// Each wrapper implements the server interface by forwarding every call to an
// implementation supplied later through its ActivateRPCServer method. Until
// then, the wrapper's methods return unavailableErr.
func main() {
	for shortName, pkgName := range protobufPkgs {
		cfg := &packages.Config{
			Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedFiles,
		}
		pkgs, err := packages.Load(cfg, pkgName)
		if err != nil {
			log.Fatalf("can't load the protobuf/gRPC proxy package: %s", err)
		}
		if len(pkgs) != 1 {
			log.Fatalf("wrong number of packages found")
		}
		pkg := pkgs[0]
		if pkg.TypesInfo == nil {
			log.Fatalf("types info not available")
		}
		if len(pkg.GoFiles) < 1 {
			log.Fatalf("no files included in package")
		}

		// We assume that our output directory is sibling to the directory
		// containing the protobuf specification.
		outDir := filepath.Join(filepath.Dir(pkg.GoFiles[0]), "../../dynrpcserver")

		for _, obj := range pkg.TypesInfo.Defs {
			// Defs can hold nil objects; the comma-ok assertion skips those too.
			typ, ok := obj.(*types.TypeName)
			if !ok {
				continue
			}
			iface, ok := typ.Type().Underlying().(*types.Interface)
			if !ok {
				continue
			}
			ifaceName := typ.Name()
			if !isServerInterface(ifaceName, iface) {
				continue
			}

			// If we get here then what we're holding _seems_ to be a gRPC
			// server interface, and so we'll generate a dynamic initialization
			// wrapper for it.
			baseName := strings.TrimSuffix(ifaceName, "Server")
			filename := toFilenameCase(baseName) + ".go"

			raw := generateWrapper(shortName, pkgName, ifaceName, baseName, iface)
			src, err := format.Source(raw)
			if err != nil {
				// Fall back to the unformatted source so the broken output can
				// be inspected, but make the problem visible.
				log.Printf("formatting %s: %s (writing unformatted source)", filename, err)
				src = raw
			}
			// os.WriteFile creates or truncates the file and always closes it.
			if err := os.WriteFile(filepath.Join(outDir, filename), src, 0o666); err != nil {
				log.Fatalf("writing %s: %s", filename, err)
			}
		}
	}
}

// isServerInterface reports whether the interface type called name looks like
// a generated gRPC service server interface that should get a wrapper.
func isServerInterface(name string, iface *types.Interface) bool {
	switch {
	case !strings.HasSuffix(name, "Server"), name == "SetupServer":
		// Doesn't look like a generated gRPC server interface.
		return false
	case strings.HasPrefix(name, "Unsafe"):
		// UnsafeFoobarServer is not a gRPC server interface. It is intended
		// to be embedded to help users to meet requirements for Unimplemented
		// servers. See:
		// > Docs: https://pkg.go.dev/google.golang.org/grpc/cmd/protoc-gen-go-grpc#readme-future-proofing-services
		// > PR for Unsafe interfaces: https://github.com/grpc/grpc-go/pull/3911
		return false
	}

	// The interfaces used for streaming requests/responses unfortunately
	// also have a "Server" suffix in the generated Go code. We detect them
	// by noticing that they have grpc.ServerStream embedded inside.
	for i := 0; i < iface.NumEmbeddeds(); i++ {
		emb, ok := iface.EmbeddedType(i).(*types.Named)
		if !ok {
			continue
		}
		obj := emb.Obj()
		// Universe types (e.g. error) have a nil package; guard before Path().
		if obj.Pkg() != nil && obj.Pkg().Path() == "google.golang.org/grpc" && obj.Name() == "ServerStream" {
			return false
		}
	}
	return true
}

// generateWrapper returns the unformatted Go source for the dynrpcserver
// wrapper type baseName. The wrapper implements shortName.ifaceName, where
// shortName is the local alias of the package at import path pkgName.
func generateWrapper(shortName, pkgName, ifaceName, baseName string, iface *types.Interface) []byte {
	var buf bytes.Buffer

	buf.WriteString(`// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

// Code generated by ./generator. DO NOT EDIT.
`)

	fmt.Fprintf(&buf, `package dynrpcserver

import (
	"context"
	"sync"

	%s

	%s %q
)

`, additionalImportsByName[shortName], shortName, pkgName)

	// The generated protobuf code declares UnimplementedFoobarServer, which
	// implements the interface's mustEmbedUnimplementedFoobarServer method.
	// Embedding it provides that method, so it is skipped in the loop below.
	unimplementedServerMethod := "mustEmbedUnimplemented" + ifaceName
	fmt.Fprintf(&buf, "type %s struct {\n", baseName)
	fmt.Fprintf(&buf, "impl %s.%s\n", shortName, ifaceName)
	buf.WriteString("mu sync.RWMutex\n")
	fmt.Fprintf(&buf, "%s.Unimplemented%s\n", shortName, ifaceName)
	buf.WriteString("}\n\n")

	fmt.Fprintf(&buf, "var _ %s.%s = (*%s)(nil)\n\n", shortName, ifaceName, baseName)

	fmt.Fprintf(&buf, "func New%sStub() *%s {\nreturn &%s{}\n}\n\n", baseName, baseName, baseName)

	for i := 0; i < iface.NumMethods(); i++ {
		method := iface.Method(i)
		if method.Name() == unimplementedServerMethod {
			continue
		}

		sig := method.Type().(*types.Signature)
		params, results := sig.Params(), sig.Results()

		// The generated interface types don't include parameter names, so we
		// synthesize a0, a1, ... for both the declaration and the call.
		paramDecls := make([]string, params.Len())
		argNames := make([]string, params.Len())
		for j := 0; j < params.Len(); j++ {
			argNames[j] = fmt.Sprintf("a%d", j)
			paramDecls[j] = argNames[j] + " " + typeRef(params.At(j).Type().String(), shortName, pkgName)
		}

		resultTypes := make([]string, results.Len())
		for j := 0; j < results.Len(); j++ {
			resultTypes[j] = typeRef(results.At(j).Type().String(), shortName, pkgName)
		}
		resultList := strings.Join(resultTypes, ", ")
		if len(resultTypes) > 1 {
			resultList = "(" + resultList + ")"
		}

		fmt.Fprintf(&buf, "func (s *%s) %s(%s) %s {\n", baseName, method.Name(), strings.Join(paramDecls, ", "), resultList)
		switch n := results.Len(); n {
		case 1:
			buf.WriteString("impl, err := s.realRPCServer()\nif err != nil {\nreturn err\n}\n")
		case 2:
			buf.WriteString("impl, err := s.realRPCServer()\nif err != nil {\nreturn nil, err\n}\n")
		default:
			log.Fatalf("don't know how to make a stub for method with %d results", n)
		}
		fmt.Fprintf(&buf, "return impl.%s(%s)\n}\n\n", method.Name(), strings.Join(argNames, ", "))
	}

	fmt.Fprintf(&buf, `
func (s *%[1]s) ActivateRPCServer(impl %[2]s.%[3]s) {
	s.mu.Lock()
	s.impl = impl
	s.mu.Unlock()
}

func (s *%[1]s) realRPCServer() (%[2]s.%[3]s, error) {
	s.mu.RLock()
	impl := s.impl
	s.mu.RUnlock()
	if impl == nil {
		return nil, unavailableErr
	}
	return impl, nil
}
`, baseName, shortName, ifaceName)

	return buf.Bytes()
}

// grpcGenericTypeRe matches the types.Type.String() form of an instantiated
// generic type declared in google.golang.org/grpc, for example:
//
//	google.golang.org/grpc.ServerStreamingServer[example.com/foo.Event]
//
// Submatch 1 is the generic type's name. Submatch 2 is its comma-separated
// list of fully-qualified type arguments. Nested generic arguments contain
// brackets and are therefore not matched.
var grpcGenericTypeRe = regexp.MustCompile(`^google\.golang\.org/grpc\.(\w+)\[([\w./,\s-]+)\]$`)

// typeRef translates fullType, the types.Type.String() rendering of a
// parameter or result type from a generated server interface, into Go source
// usable in the generated wrapper. In the wrapper, the package at import path
// pkg is imported under the local name name.
//
// The following is specialized to only the parameter types we typically
// expect to see in a server interface:
//   - context.Context and error
//   - any / interface{}
//   - named types from pkg, and pointers to them
//   - google.golang.org/grpc generic types whose type arguments are
//     themselves supported by typeRef
//
// Any other type terminates the program via log.Fatalf. This might need extra
// rules if we step outside the design idiom we've used for these services so
// far.
func typeRef(fullType, name, pkg string) string {
	switch {
	case fullType == "context.Context" || fullType == "error":
		return fullType
	case fullType == "interface{}" || fullType == "any":
		return "any"
	case strings.HasPrefix(fullType, "*"+pkg+"."):
		return "*" + name + "." + strings.TrimPrefix(fullType, "*"+pkg+".")
	case strings.HasPrefix(fullType, pkg+"."):
		return name + "." + strings.TrimPrefix(fullType, pkg+".")
	}

	// Handling use of google.golang.org/grpc.Foobar[T...] generic types, e.g.
	// google.golang.org/grpc.ClientStreamingServer[<pkg>.Req, <pkg>.Resp]
	// becomes grpc.ClientStreamingServer[<name>.Req, <name>.Resp].
	// Each type argument is translated recursively, so an unsupported
	// argument fails loudly instead of being silently dropped.
	if m := grpcGenericTypeRe.FindStringSubmatch(fullType); m != nil {
		args := strings.Split(m[2], ",")
		for i, arg := range args {
			args[i] = typeRef(strings.TrimSpace(arg), name, pkg)
		}
		return "grpc." + m[1] + "[" + strings.Join(args, ", ") + "]"
	}

	log.Fatalf("don't know what to do with parameter type %s", fullType)
	return ""
}

var (
	// firstCapPattern finds a capitalized word (uppercase letter followed by
	// lowercase letters) preceded by any character, so a separator can be
	// inserted before the word.
	firstCapPattern = regexp.MustCompile("(.)([A-Z][a-z]+)")
	// otherCapPattern finds a lowercase letter or digit directly followed by
	// an uppercase letter, catching boundaries the first pass missed.
	otherCapPattern = regexp.MustCompile("([a-z0-9])([A-Z])")
)

// toFilenameCase converts a CamelCase type name to lower snake_case for use
// as a file name, e.g. "FooBar" becomes "foo_bar". The two patterns are
// applied in order, and then the result is lowercased.
func toFilenameCase(typeName string) string {
	snake := typeName
	for _, boundary := range []*regexp.Regexp{firstCapPattern, otherCapPattern} {
		snake = boundary.ReplaceAllString(snake, "${1}_${2}")
	}
	return strings.ToLower(snake)
}
